import numpy as np
import h5py
import os
import torch
import torch.nn as nn
from util.trainer.model import build_model
from util.trainer.get_dataset import get_dataset
from util.trainer.get_dataloader import  get_nonsampler_trainloader
from util.trainer.get_criterion import  get_criterion

def get_unplotted_indices_and_plotted_losses(surf_file, xcoordinates, ycoordinates, rank, args):
    """Load already-computed losses from an HDF5 file and list the grid points still missing.

    Every grid point ``(x, y)`` is looked up in ``surf_file`` under the key
    ``f"{x}_{y}"``. Points that are found are copied into the loss grid; all
    other cells keep the sentinel value ``-1``.

    :param surf_file: path of the HDF5 file, opened read-only
    :param xcoordinates: 1-D NumPy array of x grid values (indexed with an index array below)
    :param ycoordinates: 1-D NumPy array of y grid values (indexed with an index array below)
    :param rank: process rank; progress is printed only when it is ``0`` or ``None``
    :param args: namespace; ``args.world_size`` is read for the progress message only
    :return: ``(inds, coordinates, losses)`` where ``inds`` are the flat indices of
        the cells still equal to ``-1``, ``coordinates`` is the matching
        ``(len(inds), 2)`` array of ``(x, y)`` values, and ``losses`` is the
        ``(len(xcoordinates), len(ycoordinates))`` grid.

    NOTE: ``-1`` is the "missing" sentinel, so a stored loss that is exactly
    ``-1`` is indistinguishable from a point that has not been evaluated.
    """
    is_main = rank == 0 or rank is None
    total_num = len(xcoordinates) * len(ycoordinates)
    losses = -np.ones((len(xcoordinates), len(ycoordinates)))
    count = 0

    # The context manager guarantees the file handle is released even when a
    # dataset read raises part-way through the scan.
    with h5py.File(surf_file, 'r') as f:
        if is_main:
            print(f"Finding {total_num} keys from {len(f.keys())}.")
        for i, x in enumerate(xcoordinates):
            for j, y in enumerate(ycoordinates):
                key = f"{x}_{y}"
                if key in f:
                    losses[i, j] = f[key][()].item()
                    count += 1

    if is_main:
        print(f"We find {count} points directly, {total_num-count} points required to be evaluated.Every process needs to evaluate {(total_num-count)//args.world_size} or {(total_num-count)//args.world_size+1} points.")

    # Computed on every rank (not only rank 0): each process needs the full
    # list of missing points before the work is divided.
    inds = np.flatnonzero(losses == -1)
    x_inds, y_inds = np.unravel_index(inds, losses.shape)
    coordinates = np.column_stack((xcoordinates[x_inds], ycoordinates[y_inds]))

    return inds, coordinates, losses


def split_inds(num_inds, worldsize):
    """Partition ``range(num_inds)`` into ``worldsize`` contiguous ranges.

    Each rank receives ``num_inds // worldsize`` items; the *last*
    ``num_inds % worldsize`` ranks receive one extra item each.

    :param num_inds: total number of items to distribute
    :param worldsize: number of ranks to distribute over
    :return: ``(splitted_idx, max_tasks_per_gpu, remainder)`` — the list of
        per-rank ``range`` objects, the largest per-rank item count as an
        ``int``, and ``num_inds % worldsize``.
    """
    chunk, remainder = divmod(num_inds, worldsize)
    max_tasks_per_gpu = chunk + 1 if remainder else chunk

    # Ranks at or beyond this position carry one extra item.
    first_heavy_rank = worldsize - remainder

    splitted_idx = []
    begin = 0
    for rank in range(worldsize):
        size = chunk + (rank >= first_heavy_rank)
        splitted_idx.append(range(begin, begin + size))
        begin += size

    return splitted_idx, int(max_tasks_per_gpu), remainder


def get_job_indices(args, inds, total_coords, rank):
    """Select this rank's share of the pending work, NaN-padded to a common length.

    The work is divided with ``split_inds(len(inds), args.world_size)``. The
    returned arrays all have ``max_tasks_per_proc`` rows so that every rank
    holds buffers of identical shape; unused trailing rows are NaN.

    :param args: namespace; ``args.world_size`` is read
    :param inds: array of flat indices of the points to evaluate
    :param total_coords: array of coordinates, row-aligned with ``inds``
    :param rank: rank whose share is extracted
    :return: ``(padded_inds, padded_proc_coords, inds_nums, max_tasks_per_proc,
        remainder)`` where ``inds_nums`` lists the real (unpadded) item count
        of every rank.
    """
    splitted_idx, max_tasks_per_proc, remainder = split_inds(len(inds), args.world_size)

    my_slots = splitted_idx[rank]
    my_inds = inds[my_slots]
    my_coords = total_coords[my_slots]

    # Float buffers: NaN marks the padding rows.
    padded_inds = np.full(max_tasks_per_proc, np.nan)
    padded_inds[:len(my_inds)] = my_inds

    padded_proc_coords = np.full((max_tasks_per_proc, 2), (np.nan, np.nan))
    padded_proc_coords[:len(my_coords)] = my_coords

    inds_nums = list(map(len, splitted_idx))

    return padded_inds, padded_proc_coords, inds_nums, max_tasks_per_proc, remainder

class ModelParallelScheduler:
    """Choose an evaluation batch size and a number of models to co-locate on one GPU.

    Capacity is found by trial: a candidate value is tried with a real forward
    pass and counted as infeasible when CUDA reports "out of memory".
    """

    def __init__(self, local_rank, rank, args, num_tasks, max_memory_utilization=0.9):
        """
        Initialise the scheduler.

        :param local_rank: index of the CUDA device; the device is ``f"cuda:{local_rank}"``
        :param rank: process rank; progress is printed only when it is ``0`` or ``None``
        :param args: namespace read for ``min_batch_size``, ``batch_size_per_gpu``,
            ``num_streams`` and ``model``; also passed to ``get_dataset`` and
            ``get_criterion``
        :param num_tasks: upper bound for the model-count search in :meth:`schedule`
        :param max_memory_utilization: stored on the instance; no method of this
            class reads it
        """
        self.device = f"cuda:{local_rank}"
        self.max_memory_utilization = max_memory_utilization
        self.min_batch_size = args.min_batch_size
        self.batchsize = args.batch_size_per_gpu
        self.streamcount = args.num_streams
        self.model = args.model
        self.num_tasks = num_tasks
        self.rank = rank
        self.dataset = get_dataset(args, testset=False)
        self.criterion = get_criterion(args)

    def _is_main_process(self):
        """Return True when this process should print progress (rank 0 or no rank)."""
        return self.rank == 0 or self.rank is None

    def get_free_memory(self):
        """Return the device's total memory minus ``torch.cuda.memory_allocated``, in bytes.

        NOTE(review): this only subtracts what ``torch.cuda.memory_allocated``
        reports for this process, so it is presumably an upper bound on the
        memory that is actually free — confirm before relying on it.
        """
        allocated_memory = torch.cuda.memory_allocated(self.device)
        total_memory = torch.cuda.get_device_properties(self.device).total_memory
        return total_memory - allocated_memory

    def find_optimal(self, args, test_func, min_value, max_value, use_doubling=True):
        """
        Find the largest value (a batch size or a model count) accepted by ``test_func``.

        ``test_func`` is assumed to be monotone: once a value fails, every
        larger value fails too. The value ``1`` is never probed; it is treated
        as always feasible. ``min_value`` is expected to be at least 1.

        :param args: passed through unchanged to ``test_func``
        :param test_func: callable ``test_func(args, value) -> bool``
        :param min_value: first value probed (``1`` is replaced by ``2``)
        :param max_value: largest value that may be returned
        :param use_doubling: probe ``min_value, 2*min_value, ...`` first to
            bracket the answer before bisecting
        :return: the largest accepted value found. When nothing at all is
            probed (``min_value`` already exceeds ``max_value``) ``max_value``
            is returned as-is.
        """
        low, high = min_value, max_value
        optimal_value = None

        if use_doubling:
            if low == 1:
                low = 2
            first_probe = low
            while low <= max_value and test_func(args, low):
                low *= 2

            if low > max_value:
                if low == first_probe:
                    # Nothing was probed: the range is empty or only holds 1.
                    return max_value
                last_good = low // 2
                # Doubling jumped past max_value, so max_value itself has not
                # been tried unless the last accepted probe landed on it.
                # Verify it instead of returning it blindly.
                if last_good == max_value or test_func(args, max_value):
                    return max_value
                optimal_value = last_good
                low, high = last_good + 1, max_value - 1
            else:
                # `low` is the first rejected probe; the answer lies in
                # [low // 2, low - 1].
                high = min(low - 1, max_value)
                low //= 2

        if optimal_value is None:
            optimal_value = low

        # Bisect for the largest accepted value in [low, high].
        if low == 1:
            low = 2
        while low <= high:
            mid = (low + high) // 2
            if test_func(args, mid):
                optimal_value = mid
                low = mid + 1
            else:
                high = mid - 1
        return optimal_value

    def test_batch_size(self, args, batch_size):
        """Return True when one forward pass plus loss at ``batch_size`` fits on the device.

        ``args.batch_size_per_gpu`` is overridden only while the first batch is
        fetched and is always restored afterwards, whatever happens.

        NOTE(review): unlike :meth:`test_model_count`, this probe does not use
        ``torch.no_grad()`` or call ``model.eval()`` — confirm that is intended.

        :param args: namespace passed to the data loader and ``build_model``
        :param batch_size: candidate batch size
        :return: ``False`` on a CUDA out-of-memory error, ``True`` otherwise
        :raises RuntimeError: any runtime error that is not an out-of-memory error
        """
        container_batch_size = args.batch_size_per_gpu
        args.batch_size_per_gpu = batch_size
        try:
            try:
                if self._is_main_process():
                    print(f"Testing batchsize {args.batch_size_per_gpu}...")
                inputs, target = next(iter(get_nonsampler_trainloader(args, self.dataset)))
            finally:
                # Restore on every exit path, not only on RuntimeError, so a
                # failed probe never leaves `args` mutated.
                args.batch_size_per_gpu = container_batch_size
            input_data = inputs.to(self.device)
            target = target.to(self.device)
            model = build_model(args).to(self.device)
            output = model(input_data)
            self.criterion(output, target)
            return True
        except RuntimeError as e:
            if "out of memory" in str(e):
                return False
            raise e

    def test_model_count(self, args, model_count):
        """Return True when ``model_count`` models can each evaluate one batch on the device.

        Every model runs in its own CUDA stream under ``torch.no_grad()`` at
        the batch size currently set in ``args``.

        :param args: namespace passed to the data loader and ``build_model``
        :param model_count: candidate number of co-located models
        :return: ``False`` on a CUDA out-of-memory error, ``True`` otherwise
        :raises RuntimeError: any runtime error that is not an out-of-memory error
        """
        try:
            if self._is_main_process():
                print(f"Testing collocating {model_count} models on each GPU...")
            models = [build_model(args).to(self.device) for _ in range(model_count)]
            inputs, targets = next(iter(get_nonsampler_trainloader(args, self.dataset)))
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            # Create the streams on this scheduler's device rather than on
            # whichever device happens to be current.
            streams = [torch.cuda.Stream(device=self.device) for _ in range(model_count)]
            # Keep every loss referenced until all streams have finished.
            output_list = [None] * model_count
            for i, (model, stream) in enumerate(zip(models, streams)):
                with torch.cuda.stream(stream):
                    model.eval()
                    with torch.no_grad():
                        outputs = model(inputs)
                        output_list[i] = self.criterion(outputs, targets)
            for stream in streams:
                stream.synchronize()
            return True
        except RuntimeError as e:
            if "out of memory" in str(e):
                return False
            raise e

    def schedule(self, args):
        """Set ``self.batchsize`` and ``self.streamcount``, then report them.

        When ``args.adaptive_batchsize_and_streamcount`` is set, the largest
        feasible batch size is searched first and written to
        ``args.batch_size_per_gpu``. Several models per GPU are tried only when
        the whole dataset fits in a single batch; otherwise one model is used.
        When the flag is not set, the values taken from ``args`` in
        ``__init__`` are kept.

        :param args: namespace; ``batch_size_per_gpu`` may be overwritten
        """
        if args.adaptive_batchsize_and_streamcount:
            # Largest batch size that fits, capped at the dataset size.
            self.batchsize = self.find_optimal(args, self.test_batch_size, self.min_batch_size, len(self.dataset))
            args.batch_size_per_gpu = self.batchsize

            # Only try several models per GPU when the whole dataset fits in one batch.
            if self.batchsize == len(self.dataset) and self.num_tasks >= 1:
                self.streamcount = self.find_optimal(args, self.test_model_count, 1, self.num_tasks)
            else:
                self.streamcount = 1
        if self._is_main_process():
            print(f"Batch size per stream for evaluation: {self.batchsize}")
            print(f"Number of models loaded at once per GPU: {self.streamcount}")







